import glob
import os

import numpy as np
import torch 
from torch.utils.data import Dataset
import torchvision.transforms as tfs
import cv2
from PIL import Image
import pandas as pd


class TIMG_val(Dataset):
    """Tiny ImageNet validation split read from JPEG files.

    Expects ``<dir_path>/val/val_annotations.txt`` (tab-separated, no header;
    column 0 is the image file name, column 1 its wnid) and images matching
    ``<dir_path>/val/images/val_*.JPEG``.

    Each item is ``(image, label)``:
      * ``image``: float32 RGB array of shape ``(1, image_size, image_size, 3)``
        with raw 0-255 pixel values (no normalisation, channels last).
      * ``label``: shape ``(1,)`` integer array, ``label_dict[wnid]``.

    Args:
        dir_path: Dataset root directory; a trailing separator is optional.
        label_dict: Mapping wnid -> class index, normally the training split's
            mapping so train and val labels agree.
        image_size: Side length every image is resized to.
        seed: Unused; kept for API compatibility.
    """

    def __init__(self,
                 dir_path,
                 label_dict,
                 image_size=64,
                 seed=123):
        val_dir = os.path.join(dir_path, 'val')
        self.df = pd.read_csv(os.path.join(val_dir, 'val_annotations.txt'),
                              delimiter='\t', header=None)
        # Image file name -> wnid, taken from the first two annotation columns.
        self.name2label = dict(zip(self.df.iloc[:, 0], self.df.iloc[:, 1]))
        self.label_dict = label_dict
        self.image_size = image_size
        # Sorted because glob order is filesystem-dependent. The length is taken
        # from the images actually found so __getitem__ never indexes past them.
        self._images_list = sorted(
            glob.glob(os.path.join(val_dir, 'images', 'val_*.JPEG')))
        self._num_images = len(self._images_list)

    @property
    def data_size(self):
        """Number of validation images."""
        return self._num_images

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th validation image."""
        path = self._images_list[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise OSError(f'could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, dsize=(self.image_size, self.image_size),
                           interpolation=cv2.INTER_LINEAR).astype(np.float32)
        # Leading length-1 axis so items can be concatenated into a batch.
        image = np.expand_dims(image, axis=0)
        name = os.path.basename(path)
        label = np.expand_dims(self.label_dict[self.name2label[name]], axis=0)
        return image, label
        
    
class TIMG(Dataset):
    """Tiny ImageNet training split read from the JPEG directory tree.

    Expects ``<dir_path>/train/<wnid>/images/<wnid>_<k>.JPEG``. Class indices
    are assigned to the wnid directories in sorted order, so the mapping is
    reproducible across machines.

    Each item is ``(image, label)``:
      * ``image``: float32 RGB array of shape ``(1, image_size, image_size, 3)``
        with raw 0-255 pixel values (no normalisation, channels last).
      * ``label``: shape ``(1,)`` integer array holding the class index.

    Args:
        dir_path: Dataset root directory; a trailing separator is optional.
        image_size: Side length every image is resized to.
        seed: Unused; kept for API compatibility.
        verbose: If true, print the wnid -> class-index mapping on construction.
    """

    def __init__(self,
                 dir_path,
                 image_size=64,
                 seed=123,
                 verbose=True):
        train_dir = os.path.join(dir_path, 'train')
        # Sort so class indices do not depend on filesystem glob order.
        class_dirs = sorted(p for p in glob.glob(os.path.join(train_dir, 'n*'))
                            if os.path.isdir(p))
        self.label_dict = {os.path.basename(p): i
                           for i, p in enumerate(class_dirs)}
        if verbose:
            print(self.label_dict)
        self.image_size = image_size
        self._images_list = sorted(
            glob.glob(os.path.join(train_dir, 'n*', 'images', 'n*.JPEG')))
        self._num_images = len(self._images_list)

    @property
    def data_size(self):
        """Number of training images."""
        return self._num_images

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th training image."""
        path = self._images_list[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise OSError(f'could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, dsize=(self.image_size, self.image_size),
                           interpolation=cv2.INTER_LINEAR).astype(np.float32)
        # Leading length-1 axis so items can be concatenated into a batch.
        image = np.expand_dims(image, axis=0)
        # File names look like "<wnid>_<k>.JPEG"; the prefix is the class wnid.
        wnid = os.path.basename(path).split('_')[0]
        label = np.expand_dims(self.label_dict[wnid], axis=0)
        return image, label


def dataset_to_arrays(dataset, log_every=100):
    """Materialise every ``(x, y)`` pair of *dataset* into two arrays.

    Items are concatenated along axis 0, so each ``x``/``y`` is expected to
    carry its own leading batch axis.

    Args:
        dataset: Iterable yielding ``(x, y)`` array pairs.
        log_every: Print the running item index every ``log_every`` items;
            ``0`` or ``None`` disables progress output.

    Returns:
        Tuple ``(X, Y)`` of the concatenated arrays.

    Raises:
        ValueError: If *dataset* yields no items.
    """
    xs, ys = [], []
    for idx, (x, y) in enumerate(dataset):
        xs.append(x)
        ys.append(y)
        if log_every and idx % log_every == 0:
            print(idx)
    if not xs:
        raise ValueError('dataset is empty; nothing to convert')
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


if __name__ == '__main__':
    root = '/dual_data/not_backed_up/tiny-imagenet-200/'
    out_dir = '/home/dixzhu/source/models-master/research/resnet/data'

    # Convert the training JPEGs to .npy arrays.
    train_set = TIMG(dir_path=root, image_size=64)
    trX, trY = dataset_to_arrays(train_set)
    np.save(os.path.join(out_dir, 'TIMG_train_X'), trX)
    np.save(os.path.join(out_dir, 'TIMG_train_Y'), trY)
    print(trX.shape)
    print(trY.shape)

    # The validation split reuses the training wnid -> index mapping so that
    # train and test labels agree.
    val_set = TIMG_val(dir_path=root, label_dict=train_set.label_dict,
                       image_size=64)
    vaX, vaY = dataset_to_arrays(val_set)
    np.save(os.path.join(out_dir, 'TIMG_test_X'), vaX)
    np.save(os.path.join(out_dir, 'TIMG_test_Y'), vaY)
    print(vaX.shape)
    print(vaY.shape)
    
